{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Automatic Mixed-Precision (AMP)\n",
    "\n",
    "This notebook shows a working code example of how to use AIMET to perform Automatic Mixed Precision (AMP). AMP is a technique where, given a quantized accuracy target, AIMET finds a per-layer bit precision that meets the target while trying to optimize the model for inference speed.\n",
    "\n",
    "As an example, say a particular model does not meet a desired accuracy target when run in INT8. The AMP feature will find a minimal set of layers that need to run at a higher precision, say INT16, to reach the desired accuracy. Note that choosing higher precision for some layers necessarily involves a trade-off: fewer inferences/sec for higher accuracy and vice versa.\n",
    "\n",
    "Alternatively, the AMP feature can be used to generate a pareto curve (accuracy vs. bit-ops) that can guide the user in choosing the right operating point for this trade-off.\n",
    "\n",
    "This notebook shows a working code example of the above.\n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following:\n",
    "1. Instantiate the example evaluation pipeline\n",
    "2. Convert an FP32 PyTorch model to ONNX, simplify it, and evaluate the model's baseline FP32 accuracy\n",
    "3. Create a quantization simulation model (with fake quantization ops inserted)\n",
    "4. Run the AMP algorithm on the quantized model\n",
    "\n",
    "#### What this notebook is not\n",
    "* This notebook is not designed to show state-of-the-art AMP results. For example, it uses a relatively quantization-friendly model, ResNet18. Also, some optimization parameters, like the number of samples used for evaluation, are deliberately chosen so that the notebook executes more quickly."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## Dataset\n",
    "\n",
    "This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).\n",
    "\n",
    "**Note1**: The ImageNet dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:\n",
    "- Subfolders 'train' for the training samples and 'val' for the validation samples. Please see the [pytorch dataset description](https://pytorch.org/vision/0.8/_modules/torchvision/datasets/imagenet.html) for more details.\n",
    "- A subdirectory per class, and a file per image sample\n",
    "\n",
    "**Note2**: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. E.g. the entire ILSVRC2012 dataset has 1000 classes, up to 1300 training samples per class and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is left up to the reader and is not necessary.\n",
    "\n",
    "Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/dataset'         # Please replace this with a real directory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## 1. Example evaluation pipeline\n",
    "\n",
    "The following is an example validation loop for this image classification task.\n",
    "\n",
    "- **Does AIMET have any limitations on how the validation pipeline is written?** Not really. We will see later that AIMET will modify the user's model and provide a QuantizationSim session that acts as a regular onnxruntime inference session. However, it is recommended that users only use inference sessions created by the QuantizationSimModel, as this will automatically register the required custom operators.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import onnxruntime as ort\n",
    "from Examples.common import image_net_config\n",
    "from Examples.onnx.utils.image_net_evaluator import ImageNetEvaluator\n",
    "from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader\n",
    "\n",
    "class ImageNetDataPipeline:\n",
    "\n",
    "    @staticmethod\n",
    "    def get_val_dataloader() -> torch.utils.data.DataLoader:\n",
    "        \"\"\"\n",
    "        Instantiates a validation dataloader for ImageNet dataset and returns it\n",
    "        \"\"\"\n",
    "        data_loader = ImageNetDataLoader(DATASET_DIR,\n",
    "                                         image_size=image_net_config.dataset['image_size'],\n",
    "                                         batch_size=image_net_config.evaluation['batch_size'],\n",
    "                                         is_training=False,\n",
    "                                         num_workers=image_net_config.evaluation['num_workers']).data_loader\n",
    "        return data_loader\n",
    "\n",
    "    @staticmethod\n",
    "    def evaluate(sess: ort.InferenceSession, iterations: int = None) -> float:\n",
    "        \"\"\"\n",
    "        Given an onnxruntime inference session, evaluates its Top-1 accuracy on the dataset\n",
    "        :param sess: the inference session to evaluate\n",
    "        :param iterations: passed through to ImageNetEvaluator.evaluate; None evaluates the full dataset\n",
    "        \"\"\"\n",
    "        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],\n",
    "                                      batch_size=image_net_config.evaluation['batch_size'],\n",
    "                                      num_workers=image_net_config.evaluation['num_workers'])\n",
    "\n",
    "        # The optional second argument lets this function be wrapped in a CallbackFunc later on\n",
    "        return evaluator.evaluate(sess, iterations=iterations)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "\n",
    "## 2. Convert an FP32 PyTorch model to ONNX, simplify & then evaluate baseline FP32 accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "For this example notebook, we are going to load a pretrained resnet18 model from torchvision. Similarly, you can load any pretrained PyTorch model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "import onnx\n",
    "\n",
    "input_shape = (1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)\n",
    "dummy_input = torch.randn(input_shape)\n",
    "filename = \"./resnet18.onnx\"\n",
    "\n",
    "# Load a pretrained ResNet-18 model in torch\n",
    "pt_model = resnet18(pretrained=True)\n",
    "\n",
    "# Export the torch model to onnx\n",
    "torch.onnx.export(pt_model.eval(),\n",
    "                  dummy_input,\n",
    "                  filename,\n",
    "                  training=torch.onnx.TrainingMode.EVAL,\n",
    "                  export_params=True,\n",
    "                  do_constant_folding=False,\n",
    "                  input_names=['input'],\n",
    "                  output_names=['output'],\n",
    "                  dynamic_axes={\n",
    "                      'input' : {0 : 'batch_size'},\n",
    "                      'output' : {0 : 'batch_size'},\n",
    "                  }\n",
    "                  )\n",
    "\n",
    "model = onnx.load_model(filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "It is recommended to simplify the model before using AIMET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxsim import simplify\n",
    "\n",
    "# Catch Exception rather than using a bare except, so that KeyboardInterrupt still stops the cell\n",
    "try:\n",
    "    model, _ = simplify(model)\n",
    "except Exception:\n",
    "    print('ONNX Simplifier failed. Proceeding with unsimplified model')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "We should decide whether to run the model on a CPU or CUDA device. This example code will use CUDA if available in your onnxruntime environment. You can change this logic and force a device placement if needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Fix cudnn_conv_algo_search to DEFAULT so that outputs/accuracies do not vary from one inference run to the next\n",
    "if 'CUDAExecutionProvider' in ort.get_available_providers():\n",
    "    providers = [('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'DEFAULT'}), 'CPUExecutionProvider']\n",
    "else:\n",
    "    providers = ['CPUExecutionProvider']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "Let's create an onnxruntime session and determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sess = ort.InferenceSession(model.SerializeToString(), providers=providers)\n",
    "accuracy = ImageNetDataPipeline.evaluate(sess)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## 3. Create a quantization simulation model\n",
    "\n",
    "### Fold Batch Normalization layers\n",
    "Before we determine the simulated quantized accuracy using QuantizationSimModel, we will fold the BatchNormalization (BN) layers in the model. These layers get folded into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.\n",
    "\n",
    "**Why do we need to do this?**\n",
    "On quantized runtimes (like TFLite, Snapdragon Neural Processing SDK, etc.), it is common practice to fold the BN layers. Doing so results in an inferences/sec speedup since unnecessary computation is avoided. From a floating point inference perspective, a BN-folded model is mathematically equivalent to a model with BN layers and produces the same accuracy. However, folding the BN layers can increase the range of the tensor values for the weight parameters of the adjacent layers, and this can have a negative impact on the quantized accuracy of the model (especially when using INT8 or lower precision). So, we want to simulate that on-target behavior by doing BN folding here.\n",
    "\n",
    "The following code calls AIMET to fold the BN layers in-place on the given model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight\n",
    "\n",
    "_ = fold_all_batch_norms_to_weight(model)"
   ]
  },
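  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "To illustrate the equivalence claimed above, the following is a minimal NumPy sketch that folds a BN layer into a preceding linear layer (used here instead of a convolution for brevity). It is not AIMET code: all variable names and values are hypothetical, and AIMET's own folding is done by `fold_all_batch_norms_to_weight`. The last line prints the largest weight magnitude before and after folding, which shows how folding can change the weight range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative only: every name and value here is hypothetical, not part of the AIMET API\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.standard_normal((4, 8))     # batch of 4 inputs with 8 features\n",
    "W = rng.standard_normal((8, 3))     # linear weights with 3 output channels\n",
    "b = rng.standard_normal(3)          # linear bias\n",
    "gamma, beta = rng.standard_normal(3), rng.standard_normal(3)   # BN scale and shift\n",
    "mean, var, eps = rng.standard_normal(3), rng.random(3) + 0.5, 1e-5\n",
    "\n",
    "# Reference: linear layer followed by BN in inference mode\n",
    "y_ref = gamma * ((x @ W + b) - mean) / np.sqrt(var + eps) + beta\n",
    "\n",
    "# Folded: scale each output channel of W and adjust the bias accordingly\n",
    "scale = gamma / np.sqrt(var + eps)\n",
    "W_folded = W * scale                # broadcasts over the output-channel axis\n",
    "b_folded = (b - mean) * scale + beta\n",
    "y_folded = x @ W_folded + b_folded\n",
    "\n",
    "print(np.allclose(y_ref, y_folded))              # outputs match up to floating point error\n",
    "print(np.abs(W).max(), np.abs(W_folded).max())   # weight range before and after folding"
   ]
  },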
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Create Quantization Sim Model\n",
    "\n",
    "Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them.\n",
    "\n",
    "A few of the parameters are explained here:\n",
    "- **quant_scheme**: We set this to `QuantScheme.min_max`\n",
    "    - Other quant scheme enum values, such as `QuantScheme.post_training_tf_enhanced`, can be used instead\n",
    "- **activation_type**: Setting this to `aimet_onnx.int8` essentially means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision\n",
    "- **param_type**: Setting this to `aimet_onnx.int8` essentially means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantScheme\n",
    "import aimet_onnx\n",
    "from aimet_onnx.quantsim import QuantizationSimModel\n",
    "\n",
    "sim = QuantizationSimModel(model=model,\n",
    "                           quant_scheme=QuantScheme.min_max,\n",
    "                           param_type=aimet_onnx.int8,\n",
    "                           activation_type=aimet_onnx.int8,\n",
    "                           providers=providers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "### Compute Encodings\n",
    "\n",
    "Even though AIMET has added 'quantizer' nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics, which then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as 'computing encodings'.\n",
    "\n",
    "It is beneficial if the samples used for the forward pass are well distributed. This does not mean that all classes need to be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid using an extremely biased subset of the original dataset, such as a subset consisting of only 'dark' or 'light' images.\n",
    "\n",
    "The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways; this is just an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
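   },
   "outputs": [],
   "source": [
    "def pass_calibration_data(session, max_samples=1000):\n",
    "    # max_samples is optional so the routine works with or without an extra callback argument\n",
    "    data_loader = ImageNetDataPipeline.get_val_dataloader()\n",
    "    batch_size = data_loader.batch_size\n",
    "    # Query the input name from the session passed in, not from the global FP32 session\n",
    "    input_name = session.get_inputs()[0].name\n",
    "\n",
    "    batch_cntr = 0\n",
    "    for input_data, _ in data_loader:\n",
    "\n",
    "        inputs_batch = input_data.numpy()\n",
    "        session.run(None, {input_name : inputs_batch})\n",
    "\n",
    "        batch_cntr += 1\n",
    "        # Stop once more than max_samples samples have been used for computing initial scale/offset\n",
    "        if max_samples is not None and (batch_cntr * batch_size) > max_samples:\n",
    "            break"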
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "Now we call AIMET to use the above routine to pass data through the model and then subsequently compute the quantization encodings. Encodings here refer to scale/offset quantization parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sim.compute_encodings(forward_pass_callback=pass_calibration_data)\n",
    "\n",
    "accuracy = ImageNetDataPipeline.evaluate(sim.session)\n",
    "print(accuracy)"
   ]
  },
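  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "As a rough illustration of what an encoding is, the sketch below derives a scale/offset pair from the observed min/max of a tensor and uses it to quantize and dequantize values. It assumes a simple asymmetric min-max scheme; it is not AIMET's implementation, and the helper names `compute_encoding` and `fake_quantize` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative only: a simple asymmetric min-max scheme, not AIMET's exact implementation\n",
    "def compute_encoding(samples, bitwidth=8):\n",
    "    # Include zero in the range so that 0.0 is exactly representable\n",
    "    t_min = min(float(samples.min()), 0.0)\n",
    "    t_max = max(float(samples.max()), 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    scale = max((t_max - t_min) / num_steps, 1e-12)   # guard against a zero range\n",
    "    offset = round(t_min / scale)                     # integer offset, zero or negative\n",
    "    return scale, offset\n",
    "\n",
    "def fake_quantize(x, scale, offset, bitwidth=8):\n",
    "    # Map to the integer grid, clamp to the representable range, then map back to float\n",
    "    q = np.clip(np.round(x / scale) - offset, 0, 2 ** bitwidth - 1)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "activations = np.random.default_rng(0).standard_normal(1000).astype(np.float32)\n",
    "scale, offset = compute_encoding(activations)\n",
    "max_error = np.abs(activations - fake_quantize(activations, scale, offset)).max()\n",
    "print(scale, offset, max_error)"
   ]
  },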
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 4. Run AMP algorithm  on the quantized model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The AMP algorithm runs in two phases. Phase 1 calculates the sensitivity of each layer, and phase 2 greedily selects which layers should have what bitwidth, based on options provided by the user. Both phases need to pass data through the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Define callback functions for AMP\n",
    "\n",
    "AMP requires three callback functions, `forward_pass_callback`, `eval_callback_for_phase1`, and `eval_callback_for_phase2`.\n",
    "\n",
    "`forward_pass_callback` is a CallbackFunc object which is required for computing initial scale/offset as explained above. In this example, we will reuse the same forward function as the previous code snippet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_common.defs import CallbackFunc\n",
    "\n",
    "# Use 1000 samples for computing initial scale/offset\n",
    "forward_pass_callback = CallbackFunc(pass_calibration_data, func_callback_args=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "`eval_callback_for_phase1` and `eval_callback_for_phase2` are also CallbackFunc objects, used for measuring the model's eval score in phase 1 and phase 2, respectively. Even though they are both used for evaluating the model's quality, they have slightly different goals. The eval callback for phase 1 is used to get a rough measure of the model's quality, whereas the eval callback for phase 2 is used to measure the model's quality as it would be measured in practice. This implies that the eval callback for phase 1 can be more flexible than the one for phase 2.\n",
    "\n",
    "For example, to measure your model's quality in phase 1, you can use a relatively small dataset, or even an indirect measure (e.g. SQNR between the fp32 outputs and fake-quantized outputs) that can be computed faster than the real metric but correlates well with it."
   ]
  },
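  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the indirect measure mentioned above, the sketch below computes SQNR in decibels between a reference output and a perturbed output. It assumes one common definition of SQNR (signal power over noise power) and may differ in detail from what AIMET computes internally; `sqnr_db` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative only: one common definition of SQNR in decibels\n",
    "def sqnr_db(fp32_output, quantized_output, eps=1e-10):\n",
    "    signal_power = np.mean(np.square(fp32_output))\n",
    "    noise_power = np.mean(np.square(fp32_output - quantized_output))\n",
    "    return 10.0 * np.log10(signal_power / (noise_power + eps))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "fp32_out = rng.standard_normal((8, 1000))\n",
    "small_noise = fp32_out + 0.01 * rng.standard_normal(fp32_out.shape)\n",
    "large_noise = fp32_out + 0.10 * rng.standard_normal(fp32_out.shape)\n",
    "\n",
    "# A higher SQNR means the perturbed outputs are closer to the fp32 outputs\n",
    "print(sqnr_db(fp32_out, small_noise), sqnr_db(fp32_out, large_noise))"
   ]
  },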
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Use the ONNX variant of the factory, since the callbacks operate on onnxruntime sessions\n",
    "from aimet_onnx.amp.mixed_precision_algo import EvalCallbackFactory\n",
    "\n",
    "# Phase 1 evaluation: Evaluate SQNR between fp32 outputs and fake-quantized outputs\n",
    "def forward_one_batch(session, batch):\n",
    "    image, label = batch\n",
    "\n",
    "    inputs_batch = image.numpy()\n",
    "\n",
    "    # Query the input name from the session passed in, not from the global FP32 session\n",
    "    input_name = session.get_inputs()[0].name\n",
    "\n",
    "    return session.run(None, {input_name : inputs_batch})[0]\n",
    "\n",
    "eval_callback_factory = EvalCallbackFactory(ImageNetDataPipeline.get_val_dataloader(),\n",
    "                                            forward_fn=forward_one_batch)\n",
    "# The ONNX factory needs the sim object to enable/disable quantizers while measuring SQNR\n",
    "eval_callback_for_phase1 = eval_callback_factory.sqnr(sim)\n",
    "\n",
    "# Alternatively, you can also evaluate the classification accuracy with a small subset of validation dataset\n",
    "###\n",
    "# eval_callback_for_phase1 = CallbackFunc(ImageNetDataPipeline.evaluate, func_callback_args=1000) # Use 1000 samples for phase 1 evaluation\n",
    "###"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "In phase 2, on the other hand, the eval callback should ideally measure the model's quality using the real metric (e.g. accuracy, mIoU, etc.) on the full validation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Phase 2 evaluation: Evaluate accuracy with full dataset\n",
    "eval_callback_for_phase2 = CallbackFunc(ImageNetDataPipeline.evaluate, func_callback_args=None) # Use the full dataset for phase 2 evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Parameters for AMP algorithm\n",
    "\n",
    "A few of the parameters required for AMP are explained below:\n",
    "\n",
    "- **sim**: `QuantizationSimModel` object to which mixed precision will be applied.\n",
    "- **dummy_inputs**: Dummy input to the model. If the model has more than one input, pass a tuple. The user is expected to place the tensors on the appropriate device.\n",
    "- **eval_callback_for_phase1**: A CallbackFunc object used to measure the sensitivity of each quantizer group in phase 1. Phase 1 involves finding the accuracy list/sensitivity of each module. Therefore, a user might want to run phase 1 with a smaller dataset.\n",
    "- **eval_callback_for_phase2**: A CallbackFunc object used to evaluate the accuracy of the quantized model in phase 2. Phase 2 involves finding the pareto front curve.\n",
    "- **candidates**: A list of tuples covering all allowed bitwidth and data type combinations for activations and parameters. Each tuple has the form ((activation bitwidth, activation data type), (parameter bitwidth, parameter data type)). For example, to allow (1) 8-bit int activations with 16-bit int parameters and (2) 16-bit float activations with 16-bit float parameters, candidates will be [((8, QuantizationDataType.int), (16, QuantizationDataType.int)), ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]\n",
    "- **allowed_accuracy_drop**: Maximum allowed drop in accuracy from the FP32 baseline. The pareto front curve is plotted only up to the point where the allowable accuracy drop is met. To get a complete plot for picking points on the curve, the user can set the allowable accuracy drop to None.\n",
    "- **results_dir**: Path to save results and cache intermediate results.\n",
    "- **clean_start**: If true, any cached information from previous runs will be deleted prior to starting the mixed-precision analysis. If false, prior cached information will be used if applicable. Note that it is the user's responsibility to set this flag to true if anything in the model or quantization parameters changes compared to the previous run.\n",
    "- **use_all_amp_candidates**: Using the `supported_kernels` field in the config file (under the defaults and op_type sections), a list of supported candidates can be specified. Some of the AMP candidates passed through the `candidates` field may not be supported according to `supported_kernels`. When `use_all_amp_candidates` is set to True, the AMP algorithm will ignore `supported_kernels` in the config file and continue to use all candidates.\n",
    "- **forward_pass_callback**: A CallbackFunc object which is used as the forward pass callback for computing quantization encodings.\n",
    "- **amp_search_algo**: An AMPSearchAlgo enum that defines the search algorithm to be used for phase 2. You can choose one of `AMPSearchAlgo.Binary` (default), `AMPSearchAlgo.Interpolation`, and `AMPSearchAlgo.BruteForce`.\n",
    "- **phase1_optimize**: A flag that selects between the optimized and the default implementation of phase 1. If the user sets this parameter to False, the default phase 1 logic is executed; otherwise, the optimized logic is executed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantizationDataType\n",
    "from aimet_common.amp.utils import AMPSearchAlgo\n",
    "\n",
    "candidates = [\n",
    "    ((16, QuantizationDataType.int), (16, QuantizationDataType.int)), \n",
    "    ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),\n",
    "    ((8, QuantizationDataType.int), (16, QuantizationDataType.int)),\n",
    "    ((8, QuantizationDataType.int), (8, QuantizationDataType.int)),\n",
    "]\n",
    "\n",
    "allowed_accuracy_drop = 0.001 # Allow 0.1%p accuracy drop\n",
    "\n",
    "results_dir = '/path/to/where/we/want/to/store/intermediate/and/final/results'\n",
    "\n",
    "amp_search_algo = AMPSearchAlgo.Binary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Call AMP API\n",
    "\n",
    "The AMP algorithm changes the quantization sim model in place, and the final result after running the AMP API is a model that meets the accuracy goal. The algorithm also returns a pareto curve, which is a plot of bit ops against accuracy. Bit ops are calculated by multiplying the MACs required by a layer by its output bitwidth and its parameter bitwidth. Therefore, as we lower the bitwidth for a given layer, its bit ops decrease, implying that less compute is needed for that layer.\n",
    "\n",
    "Looking at the pareto curve, a user can decide whether to change the allowed accuracy drop. Note: If a user sets clean_start to False and changes the allowed accuracy drop, then AMP will use cached data from the results directory so that re-computation is avoided."
   ]
  },
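  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bit-ops measure described above can be sketched in plain Python as follows. The layer names, MAC counts, and helper functions are hypothetical and only illustrate the arithmetic; AMP computes these values internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: the layer names and MAC counts below are hypothetical\n",
    "layers = {'conv1': 118_013_952, 'fc': 512_000}   # MACs per layer\n",
    "\n",
    "def bit_ops(macs, act_bw, param_bw):\n",
    "    # bit ops = MACs x activation (output) bitwidth x parameter bitwidth\n",
    "    return macs * act_bw * param_bw\n",
    "\n",
    "def total_bit_ops(assignment):\n",
    "    # assignment maps layer name -> (activation bitwidth, parameter bitwidth)\n",
    "    return sum(bit_ops(layers[name], act_bw, param_bw)\n",
    "               for name, (act_bw, param_bw) in assignment.items())\n",
    "\n",
    "all_int16 = {name: (16, 16) for name in layers}\n",
    "mixed = {'conv1': (8, 8), 'fc': (16, 16)}\n",
    "\n",
    "# Relative bit ops of the mixed-precision assignment versus an all-INT16 baseline\n",
    "print(total_bit_ops(mixed) / total_bit_ops(all_int16))"
   ]
  },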
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_onnx.mixed_precision import choose_mixed_precision\n",
    "\n",
    "pareto_front_list = choose_mixed_precision(sim, candidates,\n",
    "                                           eval_callback_for_phase1=eval_callback_for_phase1,\n",
    "                                           eval_callback_for_phase2=eval_callback_for_phase2,\n",
    "                                           allowed_accuracy_drop=allowed_accuracy_drop,\n",
    "                                           results_dir=results_dir,\n",
    "                                           clean_start=True,\n",
    "                                           forward_pass_callback=forward_pass_callback,\n",
    "                                           amp_search_algo=amp_search_algo,\n",
    "                                           phase1_optimize=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "We now have a mixed-precision model after AMP. The next step is to take this model to target. For this purpose, we need to export the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.makedirs('./output/', exist_ok=True)\n",
    "sim.export(path='./output/', filename_prefix='resnet18_mixed_precision')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## Summary\n",
    "\n",
    "We hope this notebook was useful for understanding how to use AIMET to perform Automatic Mixed Precision (AMP).\n",
    "\n",
    "A few additional resources:\n",
    "- Refer to the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) to know more details of the APIs and optional parameters.\n",
    "- Refer to the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/torch/quantization)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
